General LUT node #39
base: main
Conversation
```cpp
// Test to see if we have access to enable_gpnpu flag
const bool gpnpu_flag = session_options.enable_gpnpu;
// ...
const ProcessBroadcastSpanFuncs functors = gpnpu_flag ? ProcessBroadcastSpanFuncs{
```
Ternary on gpnpu_flag: when the flag is set, use MlasQLinearAddFixedPoint inside the functors instead of the original MlasQLinearAdd, which is kept in the else clause.
```cpp
}
// ...
template <typename T8Bits>
Status ComputeQLinearGlobalAvgPoolFixedPoint(
```
This is identical to ComputeQLinearGlobalAvgPool except that it calls MlasQLinearGlobalAveragePoolNchwFixedPoint instead of MlasQLinearGlobalAveragePoolNchw, and MlasQLinearGlobalAveragePoolNhwcFixedPoint instead of MlasQLinearGlobalAveragePoolNhwc. As discussed with Chris, this duplication could definitely be refactored away and this function deleted if MlasQLinearGlobalAveragePoolNchw and MlasQLinearGlobalAveragePoolNhwc had a way to check the gpnpu flag themselves. I believe I did not do that because the flag lives in the session options, which are only accessible at the top level, not down inside MlasQLinearGlobalAveragePoolNhwc and MlasQLinearGlobalAveragePoolNchw.
```cpp
    bool channels_last,
    concurrency::ThreadPool* tp);
// ...
template Status ComputeQLinearGlobalAvgPoolFixedPoint<int8_t>(
```
Explicit template instantiation, mirroring the existing structure from before.
```cpp
return ComputeQLinearGlobalAvgPool(X.Data<uint8_t>(), x_scale, *(tensor_x_zero_point->Data<uint8_t>()),
                                   Y.MutableData<uint8_t>(), y_scale, *(tensor_y_zero_point->Data<uint8_t>()),
                                   N, C, image_size, channels_last_, tp);
```
If gpnpu is set, dispatch to the fixed-point version; otherwise run the original code.
```cpp
std::vector<float> output_scales = ComputeOutputScale(a_scale, b_scale, y_scale);
std::optional<MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR> scale_bias_proc_ptr;
```
Defines two additional output processors for the fixed-point path. As discussed with Chris, this could be refactored.
```cpp
    gemm_param.OutputProcessor = &*scale_bias_proc_ptr;
  }
}
static void SetPostProcessorFixedPoint(const Tensor* y_zp,
```
The two new processors defined above get used here.
```cpp
    bool ZeroMode
);
// ...
template<typename KernelType>
```
Actually, I think this is unnecessary now: it never ended up being used anywhere, and there is no float math happening in the original MlasGemmQuantKernel.
```cpp
  }
}
// ...
template<typename DataType, bool IsScalarB>
```
Fixed-point version of MlasQLinearAddKernelRawHelper; it handles the conversion of the float scales into fixed-point form.
```cpp
  }
}
// ...
void MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR_FIXEDPOINT::Process(
```
Looking back, I'm not sure what happened here: the function is labeled fixed point, but I still see floats inside it...
```cpp
MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(Scale_);
#if !defined(MLAS_SSE2_INTRINSICS)
```
There are still float operations happening here; this needs to be revisited.
Description
Created general LUT node
Motivation and Context
Part of the ORT project.